/*
%          [L.L, L.d, skip, diagadd] = blkchol(L,ADA [,pars [,absd]])
% BLKCHOL  Computes the sparse lower-triangular (LDL') Cholesky factor L
%            and diagonal d, with L*diag(d)*L' = P(perm,perm)
%   Input Parameter L is typically generated by symbchol.
%   Parameters: pars.canceltol, pars.maxu, pars.abstol and pars.delay.
%   Optional: absd to force adding diag if drops below canceltol*absd.
%
%   There are important differences with standard CHOL(P(L.perm,L.perm))':
%
%   -  BLKCHOL uses the supernodal partition XSUPER, possibly split
%    by SPLIT, to use dense linear algebra on dense subblocks.
%    Much faster than CHOL.
%
%   -  BLKCHOL never fails. It only sees the lower triangular part of P;
%    if during the elimination, a diagonal entry becomes negative
%    (due to massive cancelation), the corresponding d[k] is set to 0.
%    If d[k] suffers from cancelation and norm(L(:,k)) becomes too big, then
%    the column is skipped, and listed in (skip, Lskip).
%
% SEE ALSO sparchol, fwblkslv, bwblkslv   

% This file is part of SeDuMi 1.1 by Imre Polik and Oleksandr Romanko
% Copyright (C) 2005 McMaster University, Hamilton, CANADA  (since 1.1)
%
% Copyright (C) 2001 Jos F. Sturm (up to 1.05R5)
%   Dept. Econometrics & O.R., Tilburg University, the Netherlands.
%   Supported by the Netherlands Organization for Scientific Research (NWO).
%
% Affiliation SeDuMi 1.03 and 1.04Beta (2000):
%   Dept. Quantitative Economics, Maastricht University, the Netherlands.
%
% Affiliations up to SeDuMi 1.02 (AUG1998):
%   CRL, McMaster University, Canada.
%   Supported by the Netherlands Organization for Scientific Research (NWO).
%
% This program is free software; you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation; either version 2 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc.,  51 Franklin Street, Fifth Floor, Boston, MA
% 02110-1301, USA


*/
#include <string.h>
#include "mex.h"
#include "blksdp.h"

/* Output arguments [L, d, skip, diagadd]. They are built in the local
   array myplhs[] inside mexFunction and only the requested ones are
   copied to plhs[] at the end; NPAROUT is the maximum output count. */
#define L_OUT    myplhs[0]
#define D_OUT    myplhs[1]
#define SKIP_OUT  myplhs[2]
#define DIAGADD_OUT  myplhs[3]
#define NPAROUT 4

/* Input arguments blkchol(L, P [,pars [,absd]]): L and P are required
   (NPARINMIN), pars and absd are optional (NPARIN arguments in total). */
#define L_IN      prhs[0]
#define P_IN      prhs[1]
#define NPARINMIN 2
#define PARS_IN   prhs[2]
#define ABSD_IN   prhs[3]
#define NPARIN 4

/* ------------------------------------------------------------
   PROTOTYPES:
   ------------------------------------------------------------ */
/* blkLDL - block sparse L*diag(d)*L' factorization on the compact
   supernodal structure (xsuper, snode, xlindx, lindx) built by spchol.
   NOTE(review): the definition is not in this file (presumably in
   another SeDuMi source); the contract below is inferred from how
   spchol calls it.
     lb    - per-column lower bounds, max(abstol, canceltol * orgd).
     lpr   - on input tril(P(perm,perm)), on output the factor L.
     d     - output diagonal.
     ub    - max(diag) / maxu^2, as computed by spchol.
     skipIr - receives the skipped or diagonal-adding columns; spchol
             separates them afterwards by the sign of d[skipIr].
   RETURNS - the number of entries written to skipIr, or (mwIndex)-1
     on failure. */
mwIndex blkLDL(const mwIndex neqns, const mwIndex nsuper, const mwIndex *xsuper,
           const mwIndex *snode,  const mwIndex *xlindx, const mwIndex *lindx,
           double *lb,
           const mwIndex *ljc, double *lpr, double *d, const mwIndex *perm,
           const double ub, const double maxu, mwIndex *skipIr,
           mwIndex iwsiz, mwIndex *iwork, mwIndex fwsiz, double *fwork);

/* ============================================================
   SUBROUTINES:
   ============================================================*/
/* ------------------------------------------------------------
   PERMUTEP - Fill in the numerical values of L = tril(P(perm,perm)).
     The sparsity pattern (Ljc, Lir) of L is given; only Lpr is written.
     For every column, the corresponding column of P is scattered into
     the dense buffer Pj, gathered through perm into L(:,col), and then
     removed from Pj again.
   INPUT
     Ljc, Lir - sparsity structure of output matrix L = tril(P(perm,perm)).
     Pjc, Pir, Ppr - input matrix P, before ordering.
     perm     - length m pivot ordering.
     m        - order: P is m x m.
   OUTPUT
     Lpr      - values of L, one for each entry of (Ljc, Lir).
   WORKING ARRAY
     Pj  - length m float work array; all-0 again on return.
   IMPORTANT: L, P and PERM in C style (0-based).
   ------------------------------------------------------------ */
void permuteP(const mwIndex *Ljc,const mwIndex *Lir,double *Lpr,
              const mwIndex *Pjc,const mwIndex *Pir,const double *Ppr,
              const mwIndex *perm, double *Pj, const mwIndex m)
{
  mwIndex col, k, src, kbeg, kend;

  fzeros(Pj,m);                      /* dense scatter buffer starts all-0 */
  for(col = 0; col < m; col++){
    src  = perm[col];                /* L(:,col) comes from P(:,perm[col]) */
    kbeg = Pjc[src];
    kend = Pjc[src+1];
    /* Scatter P(:,src) into Pj. */
    for(k = kbeg; k < kend; k++)
      Pj[Pir[k]] = Ppr[k];
    /* Gather L(:,col) = Pj(perm(:)) on the given pattern of L. */
    for(k = Ljc[col]; k < Ljc[col+1]; k++)
      Lpr[k] = Pj[perm[Lir[k]]];
    /* Clear only the touched entries, so Pj is all-0 for the next column. */
    for(k = kbeg; k < kend; k++)
      Pj[Pir[k]] = 0.0;
  }
}

/* ************************************************************
   SPCHOL - Builds the compact supernodal data structure and calls the
     block sparse L*diag(d)*L' factorization blkLDL.
   INPUT:
      m       - Order of L: L is m x m.
      nsuper  - Number of supernodes (blocks).
      xsuper  - Length nsuper+1: first simple-node of each supernode
                (C-style); xsuper[nsuper] is one past the last node.
      ljc     - Length m+1: start of the columns of L.
      perm    - Length m pivot ordering; passed on to blkLDL.
      abstol  - minimum diagonal value. If 0, then no absolute threshold.
      canceltol  - Force d >= canceltol * orgd (by adding low-rank diag).
      maxu    - Maximal allowed max(abs(L(:,k))).
                If L gets too big in these columns, we skip the pivots.
      iwsiz, fwsiz - size of integer and floating-point working storage.
                See "WORKING ARRAYS" for required amount.
   UPDATED:
      lindx   - row indices. On INPUT: for each column (by ljc),
                on OUTPUT: for each supernode (by xlindx).
      lpr     - On input, contains tril(X); on output L, such that
                X = L*diag(d)*L'.
      lb      - Length m. INPUT: diag entries BEFORE cancellation;
                OUTPUT: lb(skipIr) are values of low rank diag. matrix
                that is added before factorization.
   OUTPUT
      xlindx  - Length nsuper+1: Start of sparsity structure in lindx,
                for each supernode (all simple nodes in a supernode have
                the same nonzero-structure).
      snode   - Length m: snode[j] is the supernode containing column j.
      d       - Length m vector, diagonal in L*diag(d)*L'.
      skipIr  - Length nskip (<= m) array. skipIr[0..nskip-1] lists the
                columns whose pivot has been skipped. d[skipIr] = 0.
      pnadd   - *pnadd = nadd, the number of diagonal-adding columns;
                they are listed in iwork[0..nadd-1]. Set to 0 on failure.
   WORKING ARRAYS:
      iwork  - Length iwsiz working array; iwsiz = 2*m + 2*nsuper.
      fwork  - Length fwsiz working vector; fwsiz = L.tmpsiz.
   RETURNS - nskip (<= m), the number of skipped pivots (length of
      skipIr), or (mwIndex)-1 if blkLDL fails or iwork is too small.
   *********************************************************************** */
mwIndex spchol(const mwIndex m, const mwIndex nsuper, const mwIndex *xsuper,
           mwIndex *snode, mwIndex *xlindx, mwIndex *lindx, double *lb,
           const mwIndex *ljc, double *lpr, double *d, const mwIndex *perm,
           const double abstol,
           const double canceltol, const double maxu, mwIndex *skipIr,
           mwIndex *pnadd,
           const mwIndex iwsiz, mwIndex *iwork, const mwIndex fwsiz, double *fwork)
{
  mwIndex jsup,j,ix,jcol,collen, nskip, nadd;
  double ub, dj;
/* ------------------------------------------------------------
   No diagonal-adding columns yet. This keeps *pnadd well-defined
   on the failure returns below, where the caller would otherwise
   read an indeterminate count.
   ------------------------------------------------------------ */
  *pnadd = 0;
/* ------------------------------------------------------------
   Let ub = max(diag(L)) / maxu^2.
   Assumes the first stored entry of each column is its diagonal
   (as for tril patterns from symbchol) -- TODO confirm for callers.
   ------------------------------------------------------------ */
  ub = 0.0;
  for(j = 0; j < m; j++)
    if((dj = lpr[ljc[j]]) > ub)
      ub = dj;
  ub /= SQR(maxu);
/* ------------------------------------------------------------
   Let lb = MAX(abstol, canceltol * lbIN), where lbIN is diag(L),
   or those quantities before being affected by cancellation.
   ------------------------------------------------------------ */
  for(j = 0; j < m; j++)
    if((dj = canceltol * lb[j]) > abstol)
      lb[j] = dj;
    else
      lb[j] = abstol;
/* ------------------------------------------------------------
   SNODE: map each column to the supernode containing it
   ------------------------------------------------------------ */
  j = xsuper[0];
  for(jsup = 0; jsup < nsuper; jsup++){
    while(j < xsuper[jsup + 1])
      snode[j++] = jsup;
  }
/* ------------------------------------------------------------
   COMPRESS SUBSCRIPTS:
    Let (xlindx,lindx) = ljc(xsuper(:)), i.e store only once
    for each snode, instead of once per column.
    memmove, since destination and source may overlap (ix <= ljc[jcol]).
   ------------------------------------------------------------ */
  for(ix = 0, jsup = 0; jsup < nsuper; jsup++){
    xlindx[jsup] = ix;
    jcol = xsuper[jsup];
    collen = ljc[jcol+1] - ljc[jcol];
    memmove(lindx + ix, lindx + ljc[jcol], collen * sizeof(mwIndex));
    ix += collen;
  }
  xlindx[nsuper] = ix;
/* ------------------------------------------------------------
   Do the block sparse Cholesky L*D*L'
   ------------------------------------------------------------ */
  nskip = blkLDL(m, nsuper, xsuper, snode, xlindx, lindx, lb,
                ljc, lpr, d, perm,
                ub, maxu, skipIr, iwsiz, iwork, fwsiz, fwork);
  if(nskip == (mwIndex)-1)
    return nskip;
/* ------------------------------------------------------------
   Let iwork = diag-adding indices. Viz. where d(skipIr)>0.0.
   Let skipIr = skipIr except diag-adding indices. Hence d(skipIr)=0.
   ------------------------------------------------------------ */
  if(iwsiz < nskip)
    return (mwIndex)-1;
  ix = 0;
  nadd = 0;
  for(j = 0; j < nskip; j++){
    jsup = skipIr[j];
    if(d[jsup] > 0.0)
      iwork[nadd++] = jsup;       /* diagonal adding */
    else
      skipIr[ix++] = jsup;        /* pivot skipping */
  }
  *pnadd = nadd;
  return ix;
}

/* ============================================================
   MAIN: MEXFUNCTION
   ============================================================ */
/* ------------------------------------------------------------
   BLKCHOLREQUIRE - Raise a MATLAB error with message MSG if COND is 0.
     Unlike mxAssert, which is only active in debug MEX builds, this
     check also runs in release builds. It guards every condition whose
     violation would otherwise cause out-of-bounds memory access.
   ------------------------------------------------------------ */
static void blkcholRequire(const int cond, const char *msg)
{
  if(!cond)
    mexErrMsgTxt(msg);
}

/* ************************************************************
   PROCEDURE mexFunction - Entry for Matlab
     [L.L, L.d, skip, diagadd] = blkchol(L, P [,pars [,absd]])
   INPUT
     L    - structure with fields perm (length m, 1-based), L (m x m
            sparse pattern of the factor), xsuper (length nsuper+1,
            1-based) and tmpsiz (floating point workspace size).
     P    - m x m sparse matrix; only tril(P(perm,perm)) is used.
     pars - optional structure; fields canceltol (default 1E-12),
            maxu (default 5E2), abstol (default 1E-20, clipped at 0)
            and delay (default 0) are read when present.
     absd - optional length m vector; if given, orgd = absd(perm)
            replaces diag(P(perm,perm)) as reference for canceltol.
   OUTPUT
     1st  - sparse lower triangular factor on the pattern of L.L.
     2nd  - length m vector d, diagonal in L*diag(d)*L'.
     skip - m x 1 sparse listing the skipped pivots. Unless
            pars.delay == 1, its value at i is L(i,i) as left by the
            factorization, and L(:,i) is then reset to e_i; with
            pars.delay == 1 all values are 1.
     diagadd - m x 1 sparse listing the diagonal-adding pivots, with
            the added diagonal value.
   ************************************************************ */
void mexFunction(int nlhs, mxArray *plhs[],
  int nrhs, const mxArray *prhs[])
{
  const mxArray *L_FIELD;
  mxArray *myplhs[NPAROUT];
  mwIndex    m, i, j, iwsiz, nsuper, tmpsiz, fwsiz, nskip, nadd, m1;
  double *fwork, *d, *skipPr, *orgd;
  const double *permPr,*xsuperPr,*Ppr,*absd;
  mwIndex    *perm, *snode, *xsuper, *iwork, *xlindx, *skip, *skipJc;
  const mwIndex *LINir, *Pjc, *Pir;
  double canceltol, maxu, abstol;
  jcir   L;
  char useAbsd, useDelay;
/* ------------------------------------------------------------
   Check for proper number of arguments
   blkchol(L,P, pars,absd) with nparinmin=2.
   Always checked: prhs[] and myplhs[] are indexed by these counts.
   ------------------------------------------------------------ */
  blkcholRequire(nrhs >= NPARINMIN, "blkchol requires more input arguments");
  blkcholRequire(nlhs <= NPAROUT, "blkchol produces less output arguments");
/* ------------------------------------------------------------
   Get input matrix P to be factored.
   ------------------------------------------------------------ */
  blkcholRequire(mxIsSparse(P_IN), "P must be sparse");
  m = mxGetM(P_IN);
  blkcholRequire(m == mxGetN(P_IN), "P must be square");
  Pjc    = mxGetJc(P_IN);
  Pir    = mxGetIr(P_IN);
  Ppr    = mxGetPr(P_IN);
/* ------------------------------------------------------------
   Disassemble block Cholesky structure L
   ------------------------------------------------------------ */
  blkcholRequire(mxIsStruct(L_IN), "Parameter `L' should be a structure.");
  L_FIELD = mxGetField(L_IN,(mwIndex)0,"perm");       /* L.perm */
  blkcholRequire(L_FIELD != NULL, "Missing field L.perm.");
  blkcholRequire(m == mxGetM(L_FIELD) * mxGetN(L_FIELD), "perm size mismatch");
  permPr = mxGetPr(L_FIELD);
  L_FIELD = mxGetField(L_IN,(mwIndex)0,"L");         /* L.L */
  blkcholRequire(L_FIELD != NULL, "Missing field L.L.");
  blkcholRequire(m == mxGetM(L_FIELD) && m == mxGetN(L_FIELD), "Size L.L mismatch.");
  blkcholRequire(mxIsSparse(L_FIELD), "L.L should be sparse.");
  L.jc = mxGetJc(L_FIELD);
  LINir = mxGetIr(L_FIELD);
  L_FIELD = mxGetField(L_IN,(mwIndex)0,"xsuper");       /* L.xsuper */
  blkcholRequire(L_FIELD != NULL, "Missing field L.xsuper.");
  j = mxGetM(L_FIELD) * mxGetN(L_FIELD);       /* nsuper + 1; avoid underflow if 0 */
  blkcholRequire(j >= 1 && j <= m + 1, "Size L.xsuper mismatch.");
  nsuper = j - 1;
  xsuperPr = mxGetPr(L_FIELD);
  L_FIELD = mxGetField(L_IN,(mwIndex)0,"tmpsiz");         /* L.tmpsiz */
  blkcholRequire(L_FIELD != NULL, "Missing field L.tmpsiz.");
  tmpsiz   = (mwIndex) mxGetScalar(L_FIELD);
/* ------------------------------------------------------------
   Disassemble pars structure: canceltol, maxu
   ------------------------------------------------------------ */
  canceltol = 1E-12;           /* supply with defaults */
  maxu = 5E2;
  abstol = 1E-20;
  useAbsd = 0;
  useDelay = 0;
  absd = NULL;
  if(nrhs >= NPARINMIN + 1){       /* 3rd argument = pars */
    blkcholRequire(mxIsStruct(PARS_IN), "Parameter `pars' should be a structure.");
    if( (L_FIELD = mxGetField(PARS_IN,(mwIndex)0,"canceltol")) != NULL)
      canceltol  = mxGetScalar(L_FIELD);  /* pars.canceltol */
    if( (L_FIELD = mxGetField(PARS_IN,(mwIndex)0,"maxu")) != NULL)
      maxu = mxGetScalar(L_FIELD);  /* pars.maxu */
    if( (L_FIELD = mxGetField(PARS_IN,(mwIndex)0,"abstol")) != NULL){
      abstol = mxGetScalar(L_FIELD);  /* pars.abstol */
      abstol = MAX(abstol, 0.0);
    }
    if( (L_FIELD = mxGetField(PARS_IN,(mwIndex)0,"delay")) != NULL)
      useDelay = (char) mxGetScalar(L_FIELD);  /* pars.delay */
/* ------------------------------------------------------------
   Get optional vector absd
   ------------------------------------------------------------ */
    if(nrhs >= NPARIN){
      useAbsd = 1;
      blkcholRequire(m == mxGetM(ABSD_IN) * mxGetN(ABSD_IN), "absd size mismatch");
      absd = mxGetPr(ABSD_IN);
    }
  }
/* ------------------------------------------------------------
   Create sparse output matrix L(m x m).
   ------------------------------------------------------------ */
  L_OUT = mxCreateSparse(m,m, L.jc[m],mxREAL);
  L.ir  = mxGetIr(L_OUT);
  L.pr  = mxGetPr(L_OUT);
  memcpy(mxGetJc(L_OUT), L.jc, (m+1) * sizeof(mwIndex));
  memcpy(L.ir, LINir, L.jc[m] * sizeof(mwIndex));
/* ------------------------------------------------------------
   Create output vector d(m).
   ------------------------------------------------------------ */
  D_OUT = mxCreateDoubleMatrix(m,(mwSize)1,mxREAL);
  d     = mxGetPr(D_OUT);
/* ------------------------------------------------------------
   Compute required sizes of working arrays:
   iwsiz = 2*(m + nsuper).
   fwsiz = tmpsiz.
   ------------------------------------------------------------ */
  iwsiz = MAX(2*(m+nsuper), 1);
  fwsiz = MAX(tmpsiz, 1);
/* ------------------------------------------------------------
   Allocate working arrays:
   integer: perm(m), snode(m), xsuper(nsuper+1),
      iwork(iwsiz), xlindx(m+1), skip(m),
   double: orgd(m), fwork(fwsiz).
   ------------------------------------------------------------ */
  m1 = MAX(m,1);                  /* avoid alloc to 0 */
  perm      = (mwIndex *) mxCalloc(m1,sizeof(mwIndex));
  snode     = (mwIndex *) mxCalloc(m1,sizeof(mwIndex));
  xsuper    = (mwIndex *) mxCalloc(nsuper+1,sizeof(mwIndex));
  iwork     = (mwIndex *) mxCalloc(iwsiz,sizeof(mwIndex));
  xlindx    = (mwIndex *) mxCalloc(m+1,sizeof(mwIndex));
  skip      = (mwIndex *) mxCalloc(m1, sizeof(mwIndex));
  orgd    = (double *) mxCalloc(m1,sizeof(double));
  fwork   = (double *) mxCalloc(fwsiz,sizeof(double));
/* ------------------------------------------------------------
   Convert PERM, XSUPER to integer and C-Style.
   Range-check the doubles BEFORE the cast: converting a negative or
   NaN double to an unsigned integer type is undefined behavior, and
   out-of-range indices would be used to index Pjc, snode, etc.
   ------------------------------------------------------------ */
  for(i = 0; i < m; i++){
    blkcholRequire(permPr[i] >= 1.0 && permPr[i] <= (double) m,
                   "L.perm must have entries in 1:m.");
    perm[i] = (mwIndex) permPr[i] - 1;
  }
  for(i = 0; i <= nsuper; i++){
    blkcholRequire(xsuperPr[i] >= 1.0 && xsuperPr[i] <= (double) m + 1.0,
                   "L.xsuper must have entries in 1:m+1.");
    xsuper[i] = (mwIndex) xsuperPr[i] - 1;
  }
/* ------------------------------------------------------------
   Let L = tril(P(PERM,PERM)), uses orgd(m) as temp working storage.
   ------------------------------------------------------------ */
  permuteP(L.jc,L.ir,L.pr, Pjc,Pir,Ppr, perm, orgd, m);
/* ------------------------------------------------------------
   If no orgd has been supplied, take orgd = diag(L on input)
   Otherwise, let orgd = absd(perm).
   ------------------------------------------------------------ */
  if(useAbsd)
    for(j = 0; j < m; j++)
      orgd[j] = absd[perm[j]];
  else
    for(j = 0; j < m; j++)
      orgd[j] = L.pr[L.jc[j]];
/* ------------------------------------------------------------
   Create "snode" and "xlindx"; change L.ir to the compact subscript
   array (with xlindx), and do BLOCK SPARSE CHOLESKY.
   spchol signals failure by (mwIndex)-1. Note that "nskip >= 0" is
   always true for unsigned mwIndex, so compare with the sentinel;
   continuing would memcpy (mwIndex)-1 entries below.
   ------------------------------------------------------------ */
  nskip = spchol(m, nsuper, xsuper, snode, xlindx,
                 L.ir, orgd, L.jc, L.pr, d, perm, abstol,
                 canceltol, maxu, skip, &nadd, iwsiz, iwork, fwsiz, fwork);
  blkcholRequire(nskip != (mwIndex)-1, "Insufficient workspace in blkchol");
/* ------------------------------------------------------------
   Copy original row-indices from LINir to L.ir.
   ------------------------------------------------------------ */
  memcpy(L.ir, LINir, L.jc[m] * sizeof(mwIndex));
/* ------------------------------------------------------------
   Create output matrices skip = sparse([],[],[],m,1,nskip),
   diagadd = sparse([],[],[],m,1,nadd),
   ------------------------------------------------------------ */
  SKIP_OUT = mxCreateSparse(m,(mwSize)1, MAX(1,nskip),mxREAL);
  memcpy(mxGetIr(SKIP_OUT), skip, nskip * sizeof(mwIndex));
  skipJc = mxGetJc(SKIP_OUT);
  skipJc[0] = 0; skipJc[1] = nskip;
  skipPr   = mxGetPr(SKIP_OUT);
/* ------------------------------------------------------------
   useDelay = 1 then L(:,i) is i-th column before ith pivot; useful
     for pivot-delaying strategy. (Fwslv(L, L(:,i)) still required.)
   ------------------------------------------------------------ */
  if(useDelay == 1)
    for(j = 0; j < nskip; j++)
      skipPr[j] = 1.0;
  else
    for(j = 0; j < nskip; j++){
      i = skip[j];
      skipPr[j] = L.pr[L.jc[i]];             /* Set skipped l(:,i)=ei. */
      L.pr[L.jc[i]] = 1.0;
      fzeros(L.pr+L.jc[i]+1,L.jc[i+1]-L.jc[i]-1);
    }
  DIAGADD_OUT = mxCreateSparse(m,(mwSize)1, MAX(1,nadd),mxREAL);
  memcpy(mxGetIr(DIAGADD_OUT), iwork, nadd * sizeof(mwIndex));
  skipJc = mxGetJc(DIAGADD_OUT);
  skipJc[0] = 0; skipJc[1] = nadd;
  skipPr   = mxGetPr(DIAGADD_OUT);
  for(j = 0; j < nadd; j++)
    skipPr[j] = orgd[iwork[j]];
/* ------------------------------------------------------------
   Release working arrays.
   ------------------------------------------------------------ */
  mxFree(fwork);
  mxFree(orgd);
  mxFree(skip);
  mxFree(xlindx);
  mxFree(iwork);
  mxFree(xsuper);
  mxFree(snode);
  mxFree(perm);
/* ------------------------------------------------------------
   Copy requested output parameters (at least 1), release others.
   ------------------------------------------------------------ */
  i = MAX(nlhs, 1);
  memcpy(plhs,myplhs, i * sizeof(mxArray *));
  for(; i < NPAROUT; i++)
    mxDestroyArray(myplhs[i]);
}
